from pyspark import SparkConf, SparkContext

if __name__ == '__main__':
    conf = SparkConf().setAppName("test").setMaster("local[*]")
    sc = SparkContext(conf=conf)

    rdd = sc.parallelize([(3, 1), (3, 5), (6, 7)], 3)

    # Argument 1: number of partitions after repartitioning
    # Argument 2: custom partitioning rule, passed in as a function
    #             with signature (K) -> int (key in, partition id out)

    rdd2 = rdd.partitionBy(2, lambda x: x % 2)
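    # PySpark routes each record via partitionFunc(k) % numPartitions,
    # so keys 3 and 6 land in partitions 1 and 0 respectively here.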

    print(rdd2.getNumPartitions(), rdd2.glom().collect())
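    # Expected output (order within a partition may vary):
    #   2 [[(6, 7)], [(3, 1), (3, 5)]]

    # For comparison, a minimal sketch of the default behaviour: omitting
    # partitionFunc falls back to pyspark.rdd.portable_hash, so placement
    # follows each key's hash value rather than an explicit rule.
    rdd3 = rdd.partitionBy(2)
    print(rdd3.getNumPartitions(), rdd3.glom().collect())

    # Stop the SparkContext to release local resources.
    sc.stop()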
